{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import statsmodels.formula.api as smf\n",
    "import matplotlib.pyplot as plt\n",
    "import fastreg.linear as frl\n",
    "import fastreg.general as frg\n",
    "import fastreg.testing as frt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>one</th>\n",
       "      <th>id1</th>\n",
       "      <th>id2</th>\n",
       "      <th>x1</th>\n",
       "      <th>x2</th>\n",
       "      <th>yhat0</th>\n",
       "      <th>yhat</th>\n",
       "      <th>y0</th>\n",
       "      <th>y</th>\n",
       "      <th>Ep0</th>\n",
       "      <th>Ep</th>\n",
       "      <th>p0</th>\n",
       "      <th>p</th>\n",
       "      <th>pz</th>\n",
       "      <th>Eb0</th>\n",
       "      <th>Eb</th>\n",
       "      <th>b0</th>\n",
       "      <th>b</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>4</td>\n",
       "      <td>15</td>\n",
       "      <td>-0.373497</td>\n",
       "      <td>0.847802</td>\n",
       "      <td>-0.054538</td>\n",
       "      <td>0.028462</td>\n",
       "      <td>-0.643401</td>\n",
       "      <td>-0.948844</td>\n",
       "      <td>0.946922</td>\n",
       "      <td>1.028871</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0.486369</td>\n",
       "      <td>0.507115</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>8</td>\n",
       "      <td>30</td>\n",
       "      <td>0.533385</td>\n",
       "      <td>-0.505209</td>\n",
       "      <td>0.218989</td>\n",
       "      <td>0.384989</td>\n",
       "      <td>1.245903</td>\n",
       "      <td>1.122130</td>\n",
       "      <td>1.244818</td>\n",
       "      <td>1.469598</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "      <td>0.554530</td>\n",
       "      <td>0.595076</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>8</td>\n",
       "      <td>24</td>\n",
       "      <td>-1.403897</td>\n",
       "      <td>0.738269</td>\n",
       "      <td>-0.694684</td>\n",
       "      <td>-0.558684</td>\n",
       "      <td>-0.406558</td>\n",
       "      <td>-0.621899</td>\n",
       "      <td>0.499232</td>\n",
       "      <td>0.571961</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "      <td>0.332992</td>\n",
       "      <td>0.363852</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>1</td>\n",
       "      <td>8</td>\n",
       "      <td>74</td>\n",
       "      <td>-1.150902</td>\n",
       "      <td>0.978757</td>\n",
       "      <td>-0.494790</td>\n",
       "      <td>-0.108790</td>\n",
       "      <td>0.129366</td>\n",
       "      <td>-1.258796</td>\n",
       "      <td>0.609699</td>\n",
       "      <td>0.896919</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0.378766</td>\n",
       "      <td>0.472829</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>1</td>\n",
       "      <td>6</td>\n",
       "      <td>45</td>\n",
       "      <td>1.456634</td>\n",
       "      <td>2.104154</td>\n",
       "      <td>1.294811</td>\n",
       "      <td>1.531811</td>\n",
       "      <td>2.023731</td>\n",
       "      <td>2.275766</td>\n",
       "      <td>3.650306</td>\n",
       "      <td>4.626548</td>\n",
       "      <td>4</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0.784960</td>\n",
       "      <td>0.822271</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   one  id1  id2        x1        x2     yhat0      yhat        y0         y  \\\n",
       "0    1    4   15 -0.373497  0.847802 -0.054538  0.028462 -0.643401 -0.948844   \n",
       "1    1    8   30  0.533385 -0.505209  0.218989  0.384989  1.245903  1.122130   \n",
       "2    1    8   24 -1.403897  0.738269 -0.694684 -0.558684 -0.406558 -0.621899   \n",
       "3    1    8   74 -1.150902  0.978757 -0.494790 -0.108790  0.129366 -1.258796   \n",
       "4    1    6   45  1.456634  2.104154  1.294811  1.531811  2.023731  2.275766   \n",
       "\n",
       "        Ep0        Ep  p0  p  pz       Eb0        Eb  b0  b  \n",
       "0  0.946922  1.028871   2  0   0  0.486369  0.507115   1  1  \n",
       "1  1.244818  1.469598   1  2   2  0.554530  0.595076   1  1  \n",
       "2  0.499232  0.571961   0  2   2  0.332992  0.363852   0  0  \n",
       "3  0.609699  0.896919   0  1   1  0.378766  0.472829   0  1  \n",
       "4  3.650306  4.626548   4  0   0  0.784960  0.822271   1  1  "
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Simulation settings, named so they can be tuned in one place.\n",
    "# NOTE(review): K1 / K2 presumably set the number of id1 / id2 categories\n",
    "# (consistent with the 111 coefficient rows reported by the fe= regression:\n",
    "# 3 regressors + 9 id1 levels + 99 id2 levels) - confirm against frt.dataset.\n",
    "N_OBS = 5_000_000  # rows to simulate; the saved outputs were produced with this value\n",
    "N_ID1 = 10\n",
    "N_ID2 = 100\n",
    "SEED = 89320432  # seed handed to the data generator\n",
    "\n",
    "data = frt.dataset(N=N_OBS, K1=N_ID1, K2=N_ID2, seed=SEED)\n",
    "data.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plain OLS (no categorical effects): statsmodels vs. fastreg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 2.85 s, sys: 1.88 s, total: 4.74 s\n",
      "Wall time: 1.64 s\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Intercept    0.000533\n",
       "x1           0.599798\n",
       "x2           0.199904\n",
       "dtype: float64"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Reference fit of y0 on x1 and x2 with the statsmodels formula API; the formula\n",
    "# supplies the intercept itself. %time reports CPU and wall time for this one statement.\n",
    "%time smf.ols('y0 ~ x1 + x2', data=data).fit().params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 551 ms, sys: 538 ms, total: 1.09 s\n",
      "Wall time: 457 ms\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>coeff</th>\n",
       "      <th>stderr</th>\n",
       "      <th>low95</th>\n",
       "      <th>high95</th>\n",
       "      <th>pvalue</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>one</td>\n",
       "      <td>0.000533</td>\n",
       "      <td>0.000447</td>\n",
       "      <td>-0.000344</td>\n",
       "      <td>0.001410</td>\n",
       "      <td>0.116722</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>x1</td>\n",
       "      <td>0.599798</td>\n",
       "      <td>0.000447</td>\n",
       "      <td>0.598921</td>\n",
       "      <td>0.600675</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>x2</td>\n",
       "      <td>0.199904</td>\n",
       "      <td>0.000447</td>\n",
       "      <td>0.199027</td>\n",
       "      <td>0.200780</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        coeff    stderr     low95    high95    pvalue\n",
       "one  0.000533  0.000447 -0.000344  0.001410  0.116722\n",
       "x1   0.599798  0.000447  0.598921  0.600675  0.000000\n",
       "x2   0.199904  0.000447  0.199027  0.200780  0.000000"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same regression of y0 on x1 and x2 through fastreg, with the regressors given as a\n",
    "# column list; the result table labels the constant term `one`.\n",
    "%time frl.ols(y='y0', x=['x1', 'x2'], data=data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Sparse OLS (categorical effects `id1` and `id2`)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 5.59 s, sys: 1.59 s, total: 7.18 s\n",
      "Wall time: 6.79 s\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>coeff</th>\n",
       "      <th>stderr</th>\n",
       "      <th>low95</th>\n",
       "      <th>high95</th>\n",
       "      <th>pvalue</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>one</td>\n",
       "      <td>0.004118</td>\n",
       "      <td>0.004655</td>\n",
       "      <td>-0.005007</td>\n",
       "      <td>0.013242</td>\n",
       "      <td>0.188213</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>x1</td>\n",
       "      <td>0.600189</td>\n",
       "      <td>0.000447</td>\n",
       "      <td>0.599312</td>\n",
       "      <td>0.601065</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>x2</td>\n",
       "      <td>0.199648</td>\n",
       "      <td>0.000447</td>\n",
       "      <td>0.198772</td>\n",
       "      <td>0.200524</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>id1=1</td>\n",
       "      <td>-0.001416</td>\n",
       "      <td>0.002001</td>\n",
       "      <td>-0.005337</td>\n",
       "      <td>0.002505</td>\n",
       "      <td>0.239536</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>id1=2</td>\n",
       "      <td>0.001605</td>\n",
       "      <td>0.001999</td>\n",
       "      <td>-0.002313</td>\n",
       "      <td>0.005524</td>\n",
       "      <td>0.211000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>id2=95</td>\n",
       "      <td>0.477899</td>\n",
       "      <td>0.006317</td>\n",
       "      <td>0.465517</td>\n",
       "      <td>0.490282</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>id2=96</td>\n",
       "      <td>0.473080</td>\n",
       "      <td>0.006309</td>\n",
       "      <td>0.460715</td>\n",
       "      <td>0.485445</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>id2=97</td>\n",
       "      <td>0.482564</td>\n",
       "      <td>0.006305</td>\n",
       "      <td>0.470207</td>\n",
       "      <td>0.494921</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>id2=98</td>\n",
       "      <td>0.481532</td>\n",
       "      <td>0.006305</td>\n",
       "      <td>0.469174</td>\n",
       "      <td>0.493891</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>id2=99</td>\n",
       "      <td>0.482068</td>\n",
       "      <td>0.006314</td>\n",
       "      <td>0.469693</td>\n",
       "      <td>0.494443</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>111 rows × 5 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "           coeff    stderr     low95    high95    pvalue\n",
       "one     0.004118  0.004655 -0.005007  0.013242  0.188213\n",
       "x1      0.600189  0.000447  0.599312  0.601065  0.000000\n",
       "x2      0.199648  0.000447  0.198772  0.200524  0.000000\n",
       "id1=1  -0.001416  0.002001 -0.005337  0.002505  0.239536\n",
       "id1=2   0.001605  0.001999 -0.002313  0.005524  0.211000\n",
       "...          ...       ...       ...       ...       ...\n",
       "id2=95  0.477899  0.006317  0.465517  0.490282  0.000000\n",
       "id2=96  0.473080  0.006309  0.460715  0.485445  0.000000\n",
       "id2=97  0.482564  0.006305  0.470207  0.494921  0.000000\n",
       "id2=98  0.481532  0.006305  0.469174  0.493891  0.000000\n",
       "id2=99  0.482068  0.006314  0.469693  0.494443  0.000000\n",
       "\n",
       "[111 rows x 5 columns]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Regress y on x1 and x2 with id1 and id2 passed through `fe`; the result table then\n",
    "# carries one coefficient row per reported category level (id1=..., id2=...).\n",
    "%time frl.ols(y='y', x=['x1', 'x2'], fe=['id1', 'id2'], data=data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 12.9 s, sys: 8 s, total: 20.9 s\n",
      "Wall time: 6.87 s\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>coeff</th>\n",
       "      <th>stderr</th>\n",
       "      <th>low95</th>\n",
       "      <th>high95</th>\n",
       "      <th>pvalue</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>one</td>\n",
       "      <td>0.256540</td>\n",
       "      <td>3.148679e-07</td>\n",
       "      <td>0.256539</td>\n",
       "      <td>0.256541</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>x1</td>\n",
       "      <td>0.600186</td>\n",
       "      <td>4.517726e-04</td>\n",
       "      <td>0.599301</td>\n",
       "      <td>0.601072</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>x2</td>\n",
       "      <td>0.199648</td>\n",
       "      <td>4.703627e-04</td>\n",
       "      <td>0.198726</td>\n",
       "      <td>0.200570</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        coeff        stderr     low95    high95  pvalue\n",
       "one  0.256540  3.148679e-07  0.256539  0.256541     0.0\n",
       "x1   0.600186  4.517726e-04  0.599301  0.601072     0.0\n",
       "x2   0.199648  4.703627e-04  0.198726  0.200570     0.0"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same y on x1 and x2 regression, but id1 and id2 are passed via `absorb` instead of `fe`;\n",
    "# the result table then lists only `one`, x1 and x2.\n",
    "# NOTE(review): presumably `absorb` removes the (id1, id2) group effects before fitting\n",
    "# rather than estimating them - confirm against the fastreg documentation.\n",
    "%time frl.ols(y='y', x=['x1', 'x2'], absorb=('id1', 'id2'), data=data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Poisson regression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  0: loss = 0.5313539505004883\n",
      "  1: loss = 0.5309544801712036\n",
      "  2: loss = 0.5309544801712036\n"
     ]
    }
   ],
   "source": [
    "# Poisson fit of the integer-valued column p on x1 and x2; the call prints numbered\n",
    "# `loss = ...` lines while it runs.\n",
    "%time frg.poisson(y='p', x=['x1', 'x2'], data=data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  0: loss = 0.516063928604126\n",
      "  1: loss = 0.5147619843482971\n",
      "  2: loss = 0.5147106647491455\n"
     ]
    }
   ],
   "source": [
    "# Same Poisson fit of p on x1 and x2, now with id1 and id2 added through `fe`.\n",
    "%time frg.poisson(y='p', x=['x1', 'x2'], fe=['id1', 'id2'], data=data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7",
   "language": "python",
   "name": "python3.7"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
